import torch


class Mutate(torch.nn.Module):
    """Randomly replace tensor elements with uniform noise.

    Each element of the input is independently replaced, with probability
    ``mutation_rate``, by a fresh sample from U[0, 1); otherwise it is passed
    through unchanged.
    """

    def __init__(self, mutation_rate):
        """
        Args:
            mutation_rate: probability that any single element is replaced.
                Must lie in [0, 1] (``torch.bernoulli`` rejects other values
                when ``forward`` runs).
        """
        super().__init__()
        self.mutation_rate = mutation_rate

    def forward(self, image):
        """Add mutations to the input.

        A Bernoulli mask is drawn with keep-probability ``1 - mutation_rate``.
        Positions where the mask is 1 keep the original value. Positions where
        it is 0 are replaced by a uniform sample in [0, 1). No clamping or
        modulo is applied. The replacement value is not guaranteed to differ
        from the original.

        Args:
            image: input tensor of any shape (e.g. ``channel*width*height`` or
                a batched ``batch*channel*width*height``).

        Returns:
            Tensor with the same shape as ``image``. Its dtype follows PyTorch
            promotion between ``image`` and the default float dtype.
        """
        shape = image.shape
        device = image.device
        # Explicit default float dtype: mutation_rate may be a Python int
        # (0 or 1), and torch.bernoulli requires a floating-point tensor.
        float_dtype = torch.get_default_dtype()

        # mask is 1 with probability q = 1 - mutation_rate.
        # Build it with the input's exact shape and device so non-square,
        # batched and GPU inputs all work.
        p = self.mutation_rate * torch.ones(shape, dtype=float_dtype, device=device)
        mask = torch.bernoulli(1 - p)
        possible_mutations = torch.rand(shape, dtype=float_dtype, device=device)

        # mask == 1: keep the input (probability 1 - mu).
        # mask == 0: take the uniform sample instead.
        return image * mask + (1 - mask) * possible_mutations
